package org.zjvis.datascience.common.algo;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zjvis.datascience.common.constant.SqlTemplate;
import org.zjvis.datascience.common.enums.AlgEnum;
import org.zjvis.datascience.common.enums.SubTypeEnum;
import org.zjvis.datascience.common.sql.SqlHelper;
import org.zjvis.datascience.common.util.ToolUtil;
import org.zjvis.datascience.common.vo.TaskVO;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Linear Regression operator template class.
 *
 * <p>Builds the operator's parameter template, the statement used to train the model (a SQL
 * call for the MADlib engine, or a command line for the Spark engine), and the description of
 * the two output tables: the prediction result table and the model summary table.
 *
 * @description Linear Regression operator template class [deprecated; superseded by the model module]
 * @date 2021-12-24
 * @deprecated Superseded by the model module.
 */
@Deprecated
public class LinearRegressionAlg extends BaseAlg {

    // Derived from the class so it stays in sync with the class name; yields the same logger
    // name ("LinearRegressionAlg") as the previous hard-coded string.
    private static final Logger logger =
            LoggerFactory.getLogger(LinearRegressionAlg.class.getSimpleName());

    /** Template file describing this operator's parameters (loaded via getTemplateParamList). */
    private static final String TPL_FILENAME = "template/algo/linear_regression.json";

    /**
     * Full-table training call. Arguments: schema, source table, model table, result table,
     * ground-truth column, feature columns, grouping columns.
     */
    private static final String SQL_TPL_MADLIB = "SELECT * FROM \"%s\".\"linear_regression\"('%s', '%s', '%s', '%s', '%s', '%s')";

    /**
     * Sampled training call: the source is first wrapped in a view restricted to rows whose ID
     * column is {@code <=} the sample size. Arguments: schema, view name, source table, ID column,
     * sample size, model table, result table, ground-truth column, feature columns, grouping columns.
     */
    private static final String SQL_TPL_MADLIB_SAMPLE = "SELECT * FROM \"%s\".\"linear_regression\"('CREATE VIEW %s AS SELECT * from %s where \"%s\" <= %s', '%s', '%s', '%s', '%s', '%s')";

    /**
     * Spark command line. Arguments: -s source table, -f feature columns, -t result table,
     * -m max iterations, -label ground-truth column, -enet elastic net, -reg regularization,
     * -tol tolerance, -uk timestamp, -idcol ID column.
     */
    private static final String SQL_TPL_SPARK = "linear -s %s -f %s -t %s -m %d -label %s -enet %s -reg %s -tol %s -uk %d -idcol %s";

    /**
     * Model summary table columns as (name, type) pairs, in output order.
     * NOTE(review): these appear to mirror MADlib's linregr_train output table — confirm.
     */
    private static final String[][] MODEL_COLUMNS = {
            {"std_err", "array"},
            {"num_missing_rows_skipped", "bigint"},
            {"condition_no", "double precision"},
            {"r2", "double precision"},
            {"coef", "array"},
            {"p_values", "array"},
            {"variance_covariance", "array"},
            {"num_rows_processed", "bigint"},
            {"t_stats", "array"}
    };

    /**
     * Fills {@code data} with this operator's UI template.
     *
     * <p>Stores the parameter list from {@link #TPL_FILENAME} under "setParams", applies
     * baseInitTemplate, then sets "validate" to a single rule {@code "feature_cols,number"}
     * (presumably: the selected feature columns must be numeric — confirm against the validator).
     *
     * @param data template object to populate; mutated in place
     */
    public void initTemplate(JSONObject data) {
        JSONArray jsonArray = getTemplateParamList(TPL_FILENAME);
        data.put("setParams", jsonArray);
        baseInitTemplate(data);
        JSONArray validate = new JSONArray();
        validate.add("feature_cols,number");
        data.put("validate", validate);
    }

    /**
     * Registers the operator as {@code LINEARREGRE} under the REGRESSION sub-type and sets
     * maxParentNumber to 1 (presumably: at most one upstream node).
     */
    public LinearRegressionAlg() {
        super(AlgEnum.LINEARREGRE.name(), SubTypeEnum.REGRESSION.getVal(),
                SubTypeEnum.REGRESSION.getDesc());
        this.maxParentNumber = 1;
    }

    /**
     * Builds the training statement.
     *
     * <p>A non-empty {@code sampleTable} always selects the sampled MADlib-style call. Otherwise
     * the configured engine decides: MADlib gets the full-table SQL call, Spark gets a command line.
     *
     * <p>NOTE(review): table and column names are interpolated verbatim into the statement; they
     * must come from trusted, validated metadata, never directly from user input.
     *
     * @param sourceTable  input table
     * @param modelTable   table receiving the model summary
     * @param resultTable  table receiving the predictions
     * @param groundTruth  label column
     * @param featureCols  feature column list; an empty value means there is nothing to run
     * @param groupingCols grouping column list (initSql passes the literal {@code NULL} for none)
     * @param timeStamp    passed to the Spark command as {@code -uk}
     * @param elasticNet   passed to the Spark command as {@code -enet}
     * @param reg          passed to the Spark command as {@code -reg}
     * @param tol          passed to the Spark command as {@code -tol}
     * @param maxIter      passed to the Spark command as {@code -m}
     * @param sampleTable  name of the sampling view; non-empty selects the sampled statement
     * @return the statement, or an empty string when {@code featureCols} is empty or the engine
     *         is neither MADlib nor Spark
     */
    public String getLinearRegressionSql(String sourceTable, String modelTable, String resultTable,
                                         String groundTruth,
                                         String featureCols, String groupingCols, long timeStamp, float elasticNet,
                                         float reg, float tol, int maxIter, String sampleTable) {
        if (StringUtils.isEmpty(featureCols)) {
            return StringUtils.EMPTY;
        }
        // Locale.ROOT keeps numeric placeholders in ASCII digits regardless of the JVM's locale,
        // so the generated SQL / command line is always parseable.
        if (StringUtils.isNotEmpty(sampleTable)) {
            // Sampled run.
            // NOTE(review): this branch does not consult getEngine(), so it emits the MADlib-style
            // call even when the Spark engine is configured — confirm this is intended.
            return String.format(Locale.ROOT, SQL_TPL_MADLIB_SAMPLE, SqlTemplate.SCHEMA, sampleTable,
                    sourceTable, ID_COL, SAMPLE_NUMBER, modelTable, resultTable, groundTruth,
                    featureCols, groupingCols);
        }
        // Full run, dispatched on the configured engine.
        if (getEngine().isMadlib()) {
            return String.format(Locale.ROOT, SQL_TPL_MADLIB, SqlTemplate.SCHEMA, sourceTable,
                    modelTable, resultTable, groundTruth, featureCols, groupingCols);
        }
        if (getEngine().isSpark()) {
            return String.format(Locale.ROOT, SQL_TPL_SPARK, sourceTable, featureCols, resultTable,
                    maxIter, groundTruth, elasticNet, reg, tol, timeStamp, ID_COL);
        }
        return StringUtils.EMPTY;
    }

    /**
     * Reads the task parameters from {@code json} and builds the training statement.
     *
     * <p>Side effects: stores {@code engineName} on this instance. When the recorded "isSample"
     * state is absent/null, "SUCCESS" or "FAIL", it selects a sampled run and sets "isSample" to
     * "CREATE" in {@code json}.
     *
     * @param json        task data (source_table, model_table, out_table_rename, feature_cols,
     *                    ground_truth, max_iter, elastic_net, reg, tol, optional isSample)
     * @param sqlHelpers  not used by this operator
     * @param timeStamp   used to align table names and forwarded to the statement builder
     * @param engineName  engine to record on this instance before building the statement
     * @return the statement produced by getLinearRegressionSql
     * @throws NullPointerException if max_iter, elastic_net, reg or tol is missing
     */
    public String initSql(JSONObject json, List<SqlHelper> sqlHelpers, long timeStamp,
                          String engineName) {
        this.engineName = engineName;
        String sourceTable = ToolUtil.alignTableName(json.getString("source_table"), timeStamp);
        String modelTable = ToolUtil.alignTableName(json.getString("model_table"), timeStamp);
        String resultTable = ToolUtil.alignTableName(json.getString("out_table_rename"), timeStamp);
        String featureCols = this.getFeatureColsStr(json.getJSONArray("feature_cols"));
        String groundTruth = json.getString("ground_truth");

        // Grouping is currently disabled: "grouping_cols" is not read from json, so the grouping
        // argument is always the literal NULL. Populate this array to re-enable grouping.
        JSONArray groupings = new JSONArray();
        String groupingCols = groupings.isEmpty() ? "NULL" : this.getFeatureColsStr(groupings);

        // Name the missing key instead of failing with a bare unboxing NPE.
        int maxIter = Objects.requireNonNull(json.getInteger("max_iter"), "max_iter is required");
        float elasticNet = Objects.requireNonNull(json.getFloat("elastic_net"), "elastic_net is required");
        float reg = Objects.requireNonNull(json.getFloat("reg"), "reg is required");
        float tol = Objects.requireNonNull(json.getFloat("tol"), "tol is required");

        String sampleTable = "";
        // A null "isSample" value is treated the same as a missing key (previously it threw NPE).
        String sampleState = json.getString("isSample");
        if (sampleState == null || "SUCCESS".equals(sampleState) || "FAIL".equals(sampleState)) {
            sampleTable = resultTable.replace("solid_", "view_");
            json.put("isSample", "CREATE");
        }
        return this.getLinearRegressionSql(sourceTable, modelTable, resultTable, groundTruth,
                featureCols, groupingCols, timeStamp, elasticNet, reg, tol, maxIter, sampleTable);
    }

    /**
     * Computes the output table names and schemas for this task and stores them in its data.
     *
     * <p>Always sets "out_table_rename". If there is no input, or "feature_cols" or "ground_truth"
     * is missing after the selector/checkbox supplements run, it logs a warning and returns early.
     * Otherwise it sets "source_table" and "model_table" and writes an "output" array with two
     * entries. The first is the result table: features, ground truth and a double-precision
     * "predict" column. The second is the model summary table.
     *
     * @param vo task whose data is read and updated in place
     */
    public void defineOutput(TaskVO vo) {
        JSONObject jsonObject = vo.getData();
        String outTablePrefix = jsonObject.getString("out_table");
        String tableName = String
                .format(SqlTemplate.OUT_TABLE_NAME, outTablePrefix, vo.getPipelineId(), vo.getId());
        jsonObject.put("out_table_rename", tableName);

        JSONArray input = jsonObject.getJSONArray("input");
        if (input == null || input.isEmpty()) {
            logger.warn("input is empty");
            return;
        }
        // NOTE(review): the numeric arguments look like indexes into the template parameter
        // list (selector at 2, checkbox at 1) — confirm against the base class.
        this.supplementForSelector(jsonObject, TPL_FILENAME, 2, vo, "number");
        this.checkBoxSelectFilter(jsonObject, "number", FEATURE_COLS);
        this.supplementForCheckbox(jsonObject, TPL_FILENAME, 1, vo);
        vo.setData(jsonObject);
        if (!jsonObject.containsKey("feature_cols")) {
            logger.warn("feature_cols is not exists!!!");
            return;
        }
        if (!jsonObject.containsKey("ground_truth")) {
            logger.warn("ground_truth key is not exist");
            return;
        }

        // Keep only the last '.'-separated segment of each feature name (presumably dropping a
        // table prefix).
        List<String> featureCols = jsonObject.getJSONArray("feature_cols").toJavaList(String.class)
                .stream()
                .map(name -> {
                    String[] parts = name.split("\\.");
                    return parts[parts.length - 1];
                })
                .collect(Collectors.toList());

        JSONObject firstInput = input.getJSONObject(0);
        jsonObject.put("source_table", firstInput.getString("tableName"));
        List<String> inputCols = firstInput.getJSONArray("tableCols").toJavaList(String.class);
        List<String> inputColumnTypes = firstInput.getJSONArray("columnTypes")
                .toJavaList(String.class);

        // Result table: feature columns, then the ground truth, then "predict".
        String groundTruth = jsonObject.getString("ground_truth");
        JSONArray resultCols = new JSONArray();
        resultCols.addAll(featureCols);
        resultCols.add(groundTruth);
        resultCols.add("predict");
        JSONArray resultColTypes = new JSONArray();
        this.prepareOutputColumnTypes(resultColTypes, featureCols, inputCols, inputColumnTypes);
        resultColTypes.add(ToolUtil.getSpecColumnType(inputCols, inputColumnTypes, groundTruth));
        resultColTypes.add("double precision");

        // Model summary table: fixed schema.
        JSONArray modelCols = new JSONArray();
        JSONArray modelColTypes = new JSONArray();
        for (String[] column : MODEL_COLUMNS) {
            modelCols.add(column[0]);
            modelColTypes.add(column[1]);
        }

        String nodeName = vo.getName() == null ? AlgEnum.LINEARREGRE.toString() : vo.getName();
        JSONArray output = new JSONArray();
        output.add(buildOutputItem(tableName, resultCols, resultColTypes, nodeName));

        String modelTable = String
                .format(SqlTemplate.OUT_TABLE_NAME, outTablePrefix + "_model", vo.getPipelineId(),
                        vo.getId());
        jsonObject.put("model_table", modelTable);
        output.add(buildOutputItem(modelTable, modelCols, modelColTypes, nodeName));

        jsonObject.put("output", output);
        vo.setData(jsonObject);
    }

    /** Assembles one entry of the "output" array and tags it via setSubTypeForOutput. */
    private JSONObject buildOutputItem(String tableName, JSONArray cols, JSONArray colTypes,
                                       String nodeName) {
        JSONObject item = new JSONObject();
        item.put("tableName", tableName);
        item.put("tableCols", cols);
        item.put("nodeName", nodeName);
        item.put("columnTypes", colTypes);
        this.setSubTypeForOutput(item);
        return item;
    }
}
